import logging
from typing import Callable, Dict, List, Optional, Union

import torch
import torchmetrics
from torch import nn
from torch.nn.modules.loss import _Loss

from ..constants import FUSION_NER, LOGITS, NER_TEXT, WEIGHT
from ..data.mixup import MixupModule
from .lit_module import LitModule

# Module-level logger named after this module (standard logging.getLogger(__name__) pattern).
logger = logging.getLogger(__name__)


class NerLitModule(LitModule):
    """
    Control the loops for training, evaluation, and prediction of Named Entity Recognition. This module is independent of
    the model definition. This class inherits from Lightning's LightningModule:
    https://lightning.ai/docs/pytorch/stable/common/lightning_module.html

    Token positions whose label id is 0 are excluded from both the training/validation loss and the
    validation metric.
    """

    def __init__(
        self,
        model: nn.Module,
        optim_type: Optional[str] = None,
        lr_choice: Optional[str] = None,
        lr_schedule: Optional[str] = None,
        lr: Optional[float] = None,
        lr_decay: Optional[float] = None,
        end_lr: Optional[Union[float, int]] = None,
        lr_mult: Optional[Union[float, int]] = None,
        weight_decay: Optional[float] = None,
        warmup_steps: Optional[Union[int, float]] = None,
        loss_func: Optional[_Loss] = None,
        validation_metric: Optional[torchmetrics.Metric] = None,
        validation_metric_name: Optional[str] = None,
        custom_metric_func: Optional[Callable] = None,
        test_metric: Optional[torchmetrics.Metric] = None,
        peft: Optional[str] = None,
        trainable_param_names: Optional[List] = None,
        mixup_fn: Optional[MixupModule] = None,
        mixup_off_epoch: Optional[int] = 0,
        model_postprocess_fn: Optional[Callable] = None,
        skip_final_val: Optional[bool] = False,
        track_grad_norm: Optional[Union[int, str]] = -1,
    ):
        """
        Parameters
        ----------
        model
            A Pytorch model. The NER code paths read its ``num_classes``, ``prefix`` and
            ``label_key`` attributes.
        optim_type
            Optimizer type. We now support:
            - adamw
            - adam
            - sgd
        lr_choice
            How to set each layer's learning rate. If not specified, the default is a single
            learning rate for all layers. Otherwise, we now support two choices:
            - two_stages
                The layers in the pretrained models have a small learning rate (lr * lr_mult),
                while the newly added head layers use the provided learning rate.
            - layerwise_decay
                The layers have decreasing learning rate from the output end to the input end.
                The intuition is that later layers are more task-related, hence larger learning rates.
        lr_schedule
            Learning rate schedule. We now support:
            - cosine_decay
                Linear warmup followed by cosine decay
            - polynomial_decay
                Linear warmup followed by polynomial decay
        lr
            Learning rate.
        lr_decay
            The learning rate decay factor (0, 1). It is used only when lr_choice is "layerwise_decay".
        end_lr
            The final learning rate after decay.
        lr_mult
            The learning rate multiplier (0, 1). It is used only when lr_choice is "two_stages".
        weight_decay
            The weight decay to regularize layer weights' l2 norm.
        warmup_steps
            How many steps to warmup learning rate. If a float (0, 1), it would represent the
            percentage of steps over all the training steps. The actual number is calculated as
            "int(warmup_steps * max_steps)". If an integer, it would be the exact step number.
        loss_func
            A Pytorch loss module, e.g., nn.CrossEntropyLoss(). It is called with keyword
            arguments ``input`` and ``target``.
        validation_metric
            A torchmetrics module used in the validation stage, e.g., torchmetrics.Accuracy().
        validation_metric_name
            Name of validation metric in case that validation_metric is a aggregation metric,
            e.g., torchmetrics.MeanMetric, whose name can't reflect the real metric name.
        custom_metric_func
            A customized metric function in case that torchmetrics doesn't have the metric.
            It is generally used together with torchmetrics' aggregators, e.g., torchmetrics.MeanMetric.
            Refer to https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/aggregation.py
        test_metric
            A torchmetrics module used in the test stage, e.g., torchmetrics.Accuracy().
        peft
            Whether to use efficient finetuning strategies. This will be helpful for fast finetuning of large backbones.
            We support options such as:

            - bit_fit (only finetune the bias terms)
            - norm_fit (only finetune the weights in norm layers / bias layer)
            - lora, lora_bias, lora_norm (only finetunes decomposition matrices inserted into model, in combination with either bit_fit or norm_fit)
            - ia3, ia3_bias, ia3_norm (adds vector that scales activations by learned vectors, in combination with either bit_fit or norm_fit)
            - None (do not use efficient finetuning strategies)
        trainable_param_names
            Parameter names forwarded to ``LitModule``; presumably the parameters kept trainable
            (e.g. alongside ``peft``) — confirm against ``LitModule``.
        mixup_fn
            A ``MixupModule`` forwarded to ``LitModule``.
        mixup_off_epoch
            Forwarded to ``LitModule``; presumably the epoch after which mixup is turned off — confirm.
        model_postprocess_fn
            Optional callable applied to the model output in ``validation_step`` before the
            validation metric is computed. Also forwarded to ``LitModule``.
        skip_final_val
            Forwarded to ``LitModule``; presumably skips the final validation run — confirm.
        track_grad_norm
            Track the p-norm of gradients during training. May be set to 'inf' infinity-norm.
            If using Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them.

        """
        super().__init__(
            model=model,
            optim_type=optim_type,
            lr_choice=lr_choice,
            lr_schedule=lr_schedule,
            lr=lr,
            lr_decay=lr_decay,
            end_lr=end_lr,
            lr_mult=lr_mult,
            weight_decay=weight_decay,
            warmup_steps=warmup_steps,
            loss_func=loss_func,
            validation_metric=validation_metric,
            validation_metric_name=validation_metric_name,
            custom_metric_func=custom_metric_func,
            test_metric=test_metric,
            peft=peft,
            trainable_param_names=trainable_param_names,
            mixup_fn=mixup_fn,
            mixup_off_epoch=mixup_off_epoch,
            model_postprocess_fn=model_postprocess_fn,
            skip_final_val=skip_final_val,
            track_grad_norm=track_grad_norm,
        )

    def _flatten_active_tokens(
        self,
        logits: torch.Tensor,
        label: torch.Tensor,
    ):
        """
        Flatten token-level logits and labels, keeping only positions whose label id is non-zero.

        Shared by the loss and the validation metric so both use the same token selection.
        ``reshape`` is used instead of ``view`` so non-contiguous inputs (e.g. logits produced by
        a post-processing function) are accepted; for contiguous inputs it returns a view as before.

        Parameters
        ----------
        logits
            Token logits; reshaped to ``(-1, self.model.num_classes)``. Presumably
            ``(batch, seq_len, num_classes)`` — confirm against the model.
        label
            Token label ids; reshaped to ``(-1,)``. Must contain as many elements as ``logits``
            has rows after reshaping. Presumably ``(batch, seq_len)`` — confirm.

        Returns
        -------
        A tuple ``(active_logits, active_labels)`` restricted to positions where the label is not 0.
        """
        flat_label = label.reshape(-1)
        # Label id 0 marks positions that must not be supervised or scored.
        active_mask = flat_label != 0
        flat_logits = logits.reshape(-1, self.model.num_classes)
        return flat_logits[active_mask], flat_label[active_mask]

    def _compute_loss(
        self,
        output: Dict,
        label: torch.Tensor,
    ):
        """
        Compute the weighted sum of token-classification losses over the NER heads.

        Parameters
        ----------
        output
            Model output keyed by prefix. Only entries whose prefix is ``NER_TEXT`` or
            ``FUSION_NER`` contribute. Each such entry must hold ``LOGITS`` and may hold
            ``WEIGHT``, which scales that entry's loss (defaults to 1).
        label
            Token label ids. Positions with label id 0 are ignored.

        Returns
        -------
        The summed loss. It stays the integer 0 if no entry of ``output`` has a NER prefix.
        """
        loss = 0
        for prefix, per_output in output.items():
            if prefix not in (NER_TEXT, FUSION_NER):
                continue
            weight = per_output[WEIGHT] if WEIGHT in per_output else 1
            active_logits, active_labels = self._flatten_active_tokens(per_output[LOGITS], label)
            loss += (
                self.loss_func(
                    input=active_logits,
                    target=active_labels,
                )
                * weight
            )
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Per validation step. This function is registered by LightningModule.
        Refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#validation

        Logs ``val_loss`` and updates/logs the validation metric, computed only over tokens whose
        label id is non-zero.

        Parameters
        ----------
        batch
            A dictionary containing the mini-batch data, including both input data and
            ground-truth labels. The mini-batch data are passed to each individual model,
            which indexes its required input data by keys with its model prefix. The
            ground-truth labels are used here to compute the validation loss and metric.
            The validation metric is used for top k model selection and early stopping.
        batch_idx
            Index of mini-batch.
        """
        output, loss = self._shared_step(batch)
        if self.model_postprocess_fn:
            output = self.model_postprocess_fn(output)

        # By default, on_step=False and on_epoch=True
        self.log("val_loss", loss)
        label = batch[self.model.label_key]
        logits = output[self.model.prefix][LOGITS]
        active_logits, active_labels = self._flatten_active_tokens(logits, label)
        self._compute_metric_score(
            metric=self.validation_metric,
            custom_metric_func=self.custom_metric_func,
            logits=active_logits,
            label=active_labels,
        )
        self.log(
            self.validation_metric_name,
            self.validation_metric,
            on_step=False,
            on_epoch=True,
        )
